import os
import dmc_envs
import torch
import robel
import gym
import numpy as np
from utilis.config import ARGConfig
from utilis.default_config import default_config
from utilis.default_config import dmc_config, metaworld_config
from model.algorithm import CausalSAC
from utilis.Replaybuffer import ReplayMemory
from utilis.causal_weight import  get_sa2r_weight, get_sa2Q_weight, get_sa2r_opti_weight
import datetime
import itertools
from copy import copy
import shutil
import wandb
import csv
from metaworld_env import make_env



from torch.utils.tensorboard import SummaryWriter
import yaml
import ipdb
import metaworld
import random

def _build_environment(config):
    """Create and seed the environment selected by ``config.env_name``.

    MetaWorld goal-observable tasks are built through ``make_env`` (called
    with the seed and ``500`` -- presumably the episode horizon; confirm
    against ``metaworld_env.make_env``). Every other name goes through
    ``gym.make`` and has both the env and its action space seeded.
    """
    if config.env_name in metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.keys():
        return make_env(config.env_name, config.seed, 500)
    env = gym.make(config.env_name)
    env.seed(config.seed)
    env.action_space.seed(config.seed)
    return env


def _time_limit_mask(env, episode_steps, done):
    """Return the mask stored with a transition.

    The mask is ``1`` when the episode has reached the time horizon (so the
    "done" produced by the time limit is ignored), otherwise
    ``float(not done)``.
    """
    # ``dir(env)`` is used on purpose (instead of ``hasattr``) to keep the
    # original attribute-discovery behaviour.
    env_attrs = dir(env)
    if '_max_episode_steps' in env_attrs:  #* panda, meta-world, kitty, mujoco
        horizon = env._max_episode_steps
    elif 'max_path_length' in env_attrs:
        horizon = env.max_path_length
    else:  #* dmc
        horizon = 1000
    return 1 if episode_steps == horizon else float(not done)


def _refresh_causal_weight(config, env, agent, memory, local_buffer,
                           causal_weight, causal_computing_time, W_est):
    """Recompute the causal weights with the method named by ``config.weight``.

    Returns ``(causal_weight, causal_computing_time, W_est)``. When
    ``config.weight`` is not one of ``sa2r`` / ``sa2q`` / ``optsa2r`` the
    values passed in are returned unchanged.
    """
    # Pick the sample source once. The original passed ``local_buffer`` to
    # the "optsa2r" estimator even when ``use_local_buffer`` was False, which
    # contradicted the sibling "sa2r"/"sa2q" branches.
    buffer = local_buffer if config.use_local_buffer else memory

    if config.weight == "sa2r":
        causal_weight, causal_computing_time = get_sa2r_weight(
            env, buffer, agent,
            sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
    elif config.weight == "sa2q":
        causal_weight, causal_computing_time = get_sa2Q_weight(
            env, buffer, agent,
            sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
    elif config.weight == "optsa2r":
        causal_weight, causal_computing_time, W_est = get_sa2r_opti_weight(
            env, buffer, agent, W_est,
            sample_size=config.causal_sample_size, causal_method=config.causal_model)
    return causal_weight, causal_computing_time, W_est


def _evaluate_agent(env, agent, num_episodes):
    """Run ``num_episodes`` evaluation episodes with ``evaluate=True`` actions.

    Returns ``(avg_reward, avg_success, avg_step_length)`` where the step
    length counts, per episode, the steps taken while no success had been
    reported yet. Success is read from ``info`` under ``'success'`` or
    ``'is_success'`` during the episode, and from ``'score/success'``,
    ``'is_success'`` or ``'success'`` (in that priority) at its end.
    """
    eval_reward_list = []
    eval_success_list = []
    eval_step_length_list = []
    for _ in range(num_episodes):
        state = env.reset()
        step_rewards = []
        done = False
        first_success_time = 0
        success = False
        while not done:
            action = agent.select_action(state, evaluate=True)
            state, reward, done, info = env.step(action)
            step_rewards.append(reward)

            if 'success' in info:
                step_flag = info["success"]
            elif 'is_success' in info:
                step_flag = info["is_success"]
            else:
                # No success signal on this step: nothing to accumulate.
                continue
            success |= bool(step_flag)
            if not success:
                first_success_time += 1

        eval_reward_list.append(sum(step_rewards))
        eval_step_length_list.append(first_success_time)
        # ``info`` is the dict returned by the last step of the episode.
        if 'score/success' in info:
            eval_success_list.append(float(info['score/success']))
        elif 'is_success' in info:
            eval_success_list.append(float(info['is_success']))
        elif 'success' in info:
            eval_success_list.append(success)

    return (np.average(eval_reward_list),
            np.average(eval_success_list),
            np.average(eval_step_length_list))


def train_loop(config, msg = "default"):
    """Train a ``CausalSAC`` agent on ``config.env_name`` and log to wandb.

    The function creates the environment and agent, prepares a result
    directory (config dump, causal-weight CSV, code snapshot, checkpoint
    folder), then alternates environment steps with network updates until
    ``total_numsteps`` exceeds ``config.num_steps``. Every
    ``config.causal_sample_interval`` steps the causal weights are
    re-estimated; every ``config.eval_interval`` episodes the agent is
    evaluated when ``config.eval`` is True.

    Args:
        config: experiment configuration object (attribute access is used
            throughout this function).
        msg: sub-folder name used inside ``./results/<env_name>/``.
    """
    is_metaworld = config.env_name in metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.keys()
    env = _build_environment(config)

    # set seed
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    # Agent
    agent = CausalSAC(env.observation_space.shape[0], env.action_space, config)

    result_path = './results/{}/{}/{}_{}_{}_{}_{}'.format(config.env_name, msg, 
                                                      datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 
                                                      config.policy, config.seed, 
                                                      "autotune" if config.automatic_entropy_tuning else "",
                                                      config.msg)
    checkpoint_path = result_path + '/' + 'checkpoint'

    # training logs
    # ``checkpoint_path`` lives inside ``result_path``, so this creates both.
    os.makedirs(checkpoint_path, exist_ok=True)
    with open(os.path.join(result_path, "config.log"), 'w') as f:
        f.write(str(config))

    #* Logging Causal weight
    causal_weight_csv_file = os.path.join(result_path, "causal_weight.csv")
    with open(causal_weight_csv_file, mode='w', newline='') as csv_file:
        fieldnames = ['Time Step', 'Causal Weights']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()

    # saving code
    # The ignore list is derived from the script directory, so the snapshot
    # is taken from that same directory (not from the cwd, which may differ
    # and may contain the results tree itself).
    current_path = os.path.dirname(os.path.abspath(__file__))
    files_to_save = ['main.py', 'model','utilis']
    ignore_files = [x for x in os.listdir(current_path) if x not in files_to_save]
    shutil.copytree(current_path, result_path + '/code', ignore=shutil.ignore_patterns(*ignore_files))

    memory = ReplayMemory(config.replay_size, config.seed)
    local_buffer = ReplayMemory(config.causal_sample_size, config.seed)

    # Training Loop
    total_numsteps = 0
    updates = 0
    best_reward = -1e6
    best_success = 0.0
    causal_computing_time = 0.0
    causal_weight = np.ones(env.action_space.shape[0], dtype=np.float32)
    W_est = []
    for i_episode in itertools.count(1):
        episode_reward = 0
        episode_steps = 0
        done = False

        state = env.reset()
        while not done:
            if config.start_steps > total_numsteps:
                action = env.action_space.sample()  # Sample random action
            else:
                action = agent.select_action(state)  # Sample action from policy

            if len(memory) > config.batch_size and config.algo == "CausalSAC":
                for _ in range(config.updates_per_step):
                    #* Update parameters of causal weight
                    # NOTE(review): this check does not depend on the inner
                    # loop, so with updates_per_step > 1 the weights are
                    # re-estimated several times on the same step. Left
                    # as-is to keep behaviour unchanged.
                    if (total_numsteps % config.causal_sample_interval == 0) and (len(local_buffer) >= config.causal_sample_size):
                        causal_weight, causal_computing_time, W_est = _refresh_causal_weight(
                            config, env, agent, memory, local_buffer,
                            causal_weight, causal_computing_time, W_est)
                        print("Current Causal Weight is: ",causal_weight)
                        wandb.log(
                            data={
                                'Causal/Computing Time': causal_computing_time,
                            },
                            step = total_numsteps
                        )
                        with open(causal_weight_csv_file, mode='a', newline='') as csv_file:
                            csv_writer = csv.writer(csv_file)
                            csv_writer.writerow([total_numsteps, ', '.join(map(str, causal_weight))])

                    # Update parameters of all the networks
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac = agent.update_parameters(memory, causal_weight,config.batch_size, updates)

                    wandb.log(
                        data = {
                            'loss/q_critic_1': critic_1_loss,
                            'loss/q_critic_2': critic_2_loss,
                            'loss/policy_loss': policy_loss,
                            'loss/entropy_loss': ent_loss, 
                            'parameter/alpha': alpha,
                            'Q value comparison': {'q_exploration': q_sac},  
                        },
                        step = total_numsteps
                    )

                    updates += 1

            next_state, reward, done, info = env.step(action) # Step
            total_numsteps += 1
            episode_steps += 1
            episode_reward += reward

            #* Ignore the "done" signal if it comes from hitting the time horizon.
            mask = _time_limit_mask(env, episode_steps, done)

            memory.push(state, action, reward, next_state, mask) # Append transition to memory
            local_buffer.push(state, action, reward, next_state, mask) # Append transition to local_buffer
            state = next_state

        if total_numsteps > config.num_steps:
            break

        wandb.log(
            data={
                'reward/train_reward': episode_reward
            },
            step = total_numsteps
        )
        print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2)))

        # test agent
        if i_episode % config.eval_interval == 0 and config.eval is True:
            avg_reward, avg_success, avg_step_length = _evaluate_agent(env, agent, config.eval_episodes)

            if config.save_checkpoint and avg_reward >= best_reward:
                best_reward = avg_reward
                if is_metaworld and avg_success >= best_success:  #* For MetaWorld
                    best_success = avg_success
                # The MetaWorld branch previously read
                # ``TODO: agent.save_checkpoint(...)``, which Python treats as
                # a local annotation and never executes, so no checkpoint was
                # ever written for MetaWorld tasks.
                agent.save_checkpoint(checkpoint_path, 'best')

            wandb.log(
                data = {
                    'reward/test_avg_reward': avg_reward,
                    'reward/success_rate': avg_success,
                    'reward/episode_length': avg_step_length
                },
                step = total_numsteps
            )

            print("----------------------------------------")
            print("Env: {}, Algo:{},  Test Episodes: {}, Avg. Reward: {}, Avg. Success: {}".format(config.env_name, config.algo, config.eval_episodes, round(avg_reward, 2), round(avg_success, 2)))
            print("----------------------------------------")

    # Manual switch: flip to True to dump the replay buffer at the end of training.
    save_buffer = False
    if save_buffer:
        memory_save_path = os.path.join(result_path, "buffer_%d.pickle"%total_numsteps)
        memory.save_buffer(memory_save_path)
    env.close() 


def main():
    """Parse command-line arguments, build the config, and launch training.

    Selects a base config from the environment name (MetaWorld, MuJoCo, or
    DMC as the fallback), overlays the parsed arguments, initialises a wandb
    run and calls ``train_loop``.
    """
    arg = ARGConfig()
    arg.add_arg("env_name", "hammer-v2-goal-observable", "Environment name")
    arg.add_arg("device", "3", "Computing device")
    arg.add_arg("policy", "Gaussian", "Policy Type: Gaussian | Deterministic (default: Gaussian)")
    arg.add_arg("tag", "default", "Experiment tag")
    arg.add_arg("algo", "CausalSAC", "choose algorithm (OPT-Q, SAC, TD3, OPT-TD3)")
    arg.add_arg("start_steps", 10000, "Number of start steps")
    arg.add_arg("automatic_entropy_tuning", True, "Automaically adjust α (default: False)")
    arg.add_arg("seed", 123457, "experiment seed")
    arg.add_arg("des", "", "short description for the experiment")
    arg.add_arg("num_steps", 1000001, "total number of steps")
    arg.add_arg("save_checkpoint", False, "save checkpoint or not")
    arg.add_arg("replay_size", 1000000, "size of replay buffer")
    arg.add_arg("weight", "sa2r", "type_of_weight") #* sa2r, optsa2r, sa2q
    arg.add_arg("use_local_buffer", True, "use_local_buffer_or_not")
    # Help text fixed: it was a copy of the ``causal_sample_size`` description.
    arg.add_arg("causal_sample_interval", 10000, "number of environment steps between causal weight updates")
    arg.add_arg("causal_sample_size", 10000, "sample_size for causal computing")
    arg.add_arg("causal_model","DagmaNonlinear", "causal model type") #* DagmaNonlinear, DagmaLinear, DirectLiNGAM
    arg.parser()

    mujoco_envs_list = ["Walker2d-v2","Hopper-v2", "Ant-v2", "Humanoid-v2", "Swimmer-v2", "HumanoidStandup-v2", "HalfCheetah-v2"]
    #* config file
    if arg.env_name in metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.keys():
        config = metaworld_config  
    elif arg.env_name in mujoco_envs_list:
        config = default_config
    else:
        config = dmc_config

    config.update(arg)
    # Non-"opt" weights use DirectLiNGAM. Apply the override *after* merging
    # the parsed arguments: ``arg`` always carries a ``causal_model`` value,
    # so setting it before ``config.update(arg)`` could be overwritten and the
    # run id would then report the wrong causal model.
    if not (arg.weight).startswith("opt"):
        config.causal_model = 'DirectLiNGAM'

    algorithm = config.algo
    # 123457 is the sentinel default seed: replace it with a random one.
    if config["seed"] == 123457:
        config["seed"] = np.random.randint(1000)

    experiment_name = "{}-{}-{}-w{}-l{}-s{}-{}".format(
        algorithm, 
        config['env_name'], 
        str(config["seed"]), 
        config["weight"],
        config["use_local_buffer"],
        config["causal_sample_interval"],
        config["causal_sample_size"],
    )

    run_id = "{}_{}_{}_{}_l{}_{}{}-{}_{}".format(
        algorithm, 
        config['env_name'],
        str(config["seed"]), 
        config["weight"],
        config["use_local_buffer"],
        config["causal_model"],
        config["causal_sample_interval"],
        config["causal_sample_size"],
        datetime.datetime.now().strftime("%Y-%m-%d_%H")
    )

    run = wandb.init(
        project = config["project_name"],
        config = {
            "env_name": config['env_name'],
            "automatic_entropy_tuning": config["automatic_entropy_tuning"],
            "algorithm" : algorithm,
            "seed": config["seed"],
            "weight": config["weight"],
            "use_local_buffer": config["use_local_buffer"],
            "causal_sample_interval": config["causal_sample_interval"],
            "causal_sample_size": config["causal_sample_size"],
            "num_steps": config["num_steps"]
        },
        name = experiment_name,
        id = run_id,
        save_code = False
    )

    print(f">>>> Training {algorithm} on {config.env_name} environment, on {config.device}")
    train_loop(config, msg=algorithm)
    wandb.finish()


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
